{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Switch to test mode\n",
    "Add the following flag to the `main.py` command line:\n",
    "```\n",
    "--phase-test True\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Verify the performance of the model on the test dataset\n",
    "Run the following commands; the performance of each model on the test set can then be found in the `./submission/` directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Test-phase run (--phase-test True) of main.py: config cfgs/fishnet150-32.yaml, checkpoint fishnet150_bs32/_15.pth.tar, with --gpus 0\n",
    "!python main.py --config=\"cfgs/fishnet150-32.yaml\" --resume ./checkpoints/fishnet150_bs32/_15.pth.tar  --phase-test True --val True --val-save True --gpus 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Test-phase run (--phase-test True) of main.py: config cfgs/fishnet150-32.yaml, checkpoint fishnet150_bs32/_51_best.pth.tar, with --gpus 0\n",
    "!python main.py --config=\"cfgs/fishnet150-32.yaml\" --resume ./checkpoints/fishnet150_bs32/_51_best.pth.tar --phase-test True --val True --val-save True --gpus 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Test-phase run (--phase-test True) of main.py: config cfgs/fishnet150-32.yaml, checkpoint fishnet150_bs32/_26_best.pth.tar\n",
    "# NOTE(review): unlike the two fishnet150 runs above, no --gpus flag is passed here - confirm this is intended\n",
    "!python main.py --config=\"cfgs/fishnet150-32.yaml\" --resume ./checkpoints/fishnet150_bs32/_26_best.pth.tar --phase-test True --val True --val-save True "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Test-phase run (--phase-test True) of main.py: config cfgs/FeatherNet54-32.yaml, checkpoint FeatherNet54/_40_best.pth.tar\n",
    "!python main.py --config=\"cfgs/FeatherNet54-32.yaml\" --resume ./checkpoints/FeatherNet54/_40_best.pth.tar --phase-test True --val True --val-save True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Test-phase run (--phase-test True) of main.py: config cfgs/FeatherNet54-se-64.yaml, checkpoint FeatherNet54-se/_68_best.pth.tar\n",
    "!python main.py --config=\"cfgs/FeatherNet54-se-64.yaml\" --resume ./checkpoints/FeatherNet54-se/_68_best.pth.tar --phase-test True --val True --val-save True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Test-phase run (--phase-test True) of main.py: config cfgs/mobilenetv2.yaml, checkpoint mobilenetv2_bs32/_4_best.pth.tar\n",
    "!python main.py --config=\"cfgs/mobilenetv2.yaml\" --resume ./checkpoints/mobilenetv2_bs32/_4_best.pth.tar --phase-test True --val True --val-save True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Test-phase run (--phase-test True) of main.py: config cfgs/mobilenetv2.yaml, checkpoint mobilenetv2_bs32/_5.pth.tar\n",
    "!python main.py --config=\"cfgs/mobilenetv2.yaml\" --resume ./checkpoints/mobilenetv2_bs32/_5.pth.tar --phase-test True --val True --val-save True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Test-phase run (--phase-test True) of main.py: config cfgs/mobilenetv2.yaml, checkpoint mobilenetv2_bs32/_6.pth.tar\n",
    "!python main.py --config=\"cfgs/mobilenetv2.yaml\" --resume ./checkpoints/mobilenetv2_bs32/_6.pth.tar --phase-test True --val True --val-save True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Test-phase run (--phase-test True) of main.py: config cfgs/MobileLiteNetA-32.yaml, checkpoint mobilelitenetA_bs32/_50_best.pth.tar\n",
    "!python main.py --config=\"cfgs/MobileLiteNetA-32.yaml\" --resume ./checkpoints/mobilelitenetA_bs32/_50_best.pth.tar --phase-test True --val True --val-save True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Test-phase run (--phase-test True) of main.py: config cfgs/MobileLiteNetB-32.yaml, checkpoint mobilelitenetB_bs32/_47_best.pth.tar\n",
    "!python main.py --config=\"cfgs/MobileLiteNetB-32.yaml\" --resume ./checkpoints/mobilelitenetB_bs32/_47_best.pth.tar --phase-test True --val True --val-save True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test with IR "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Test-phase run (--phase-test True) of main.py: config cfgs/MobileLiteNetB-32.yaml, checkpoint 4mobilelitenetB_bs32-ir/_54.pth.tar\n",
    "# The only run in this notebook that passes --phase-ir 3 (presumably selects the IR modality - confirm against main.py)\n",
    "!python main.py --config=\"cfgs/MobileLiteNetB-32.yaml\" --resume ./checkpoints/4mobilelitenetB_bs32-ir/_54.pth.tar --phase-test True --val True --val-save True --phase-ir 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Split the predicted scores from each submission file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def splitscore(file_dir):\n",
    "    \"\"\"Parse a submission file into scores and line prefixes.\n",
    "\n",
    "    Every non-blank line is split on whitespace. Its last token is parsed\n",
    "    as a float score; its first three tokens are re-joined with single\n",
    "    spaces, followed by one trailing space, to form the prefix.\n",
    "\n",
    "    Args:\n",
    "        file_dir: path of the text file to read (anything `open` accepts).\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(score, Prefix_str)` of two parallel lists with one entry\n",
    "        per non-blank line.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: a non-blank line has fewer than three tokens.\n",
    "        ValueError: the last token of a line is not a valid float.\n",
    "    \"\"\"\n",
    "    score = []\n",
    "    Prefix_str = []\n",
    "    # The context manager closes the handle even when a line fails to parse\n",
    "    with open(file_dir) as f:\n",
    "        for line in f:\n",
    "            fields = line.split()\n",
    "            if not fields:\n",
    "                # Tolerate blank lines (e.g. an extra newline at end of file)\n",
    "                continue\n",
    "            score.append(float(fields[-1]))\n",
    "            # Explicit indexing so a short line still raises IndexError\n",
    "            Prefix_str.append('{} {} {} '.format(fields[0], fields[1], fields[2]))\n",
    "    return score, Prefix_str"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Collect all predicted scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from pathlib import Path  \n",
    "# Directory that contains the prediction score files (*.txt) of all models\n",
    "data_dir = '/home/zp/disk1T/libxcam-testset/submission/checkpoints_testset_predictes'\n",
    "txt_files = list(Path(data_dir).glob(\"**/*.txt\"))\n",
    "\n",
    "# Number of per-model prediction files combined by the ensemble\n",
    "NUM_MODELS = 10\n",
    "\n",
    "# NOTE(review): glob() gives no ordering guarantee, yet the ensemble weights are\n",
    "# matched to these files by position - confirm the printed order is the intended one.\n",
    "scores = []\n",
    "for model_idx in range(NUM_MODELS):\n",
    "    # Index instead of slicing so that a missing file still raises IndexError\n",
    "    prediction_file = txt_files[model_idx]\n",
    "    print(model_idx + 1, prediction_file)\n",
    "    # Prefix_str is overwritten on every pass; only the last file's prefixes are kept\n",
    "    model_scores, Prefix_str = splitscore(prediction_file)\n",
    "    scores.append(model_scores)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import confusion_matrix\n",
    "import roc\n",
    "import numpy as np\n",
    "def get_ensembled_scores(threshold_min,threshold_max,ir_threshold, fish_threshold,scores,ir_score,fish_scores,factors):\n",
    "    \"\"\"Fuse per-model scores into a single score per sample.\n",
    "\n",
    "    For sample `i` the factor-weighted mean of the model scores is computed and\n",
    "    the result is chosen as follows:\n",
    "\n",
    "    * mean < `threshold_min`: the minimum model score.\n",
    "    * mean > `threshold_max`: the maximum model score.\n",
    "    * otherwise, if `fish_scores[i] < fish_threshold` or\n",
    "      `ir_score[i] < ir_threshold`: the minimum over the model scores,\n",
    "      `ir_score[i]` and `fish_scores[i]`.\n",
    "    * otherwise the mean is re-averaged with `fish_scores[i]` (weight 1); if that\n",
    "      exceeds 0.5 the maximum of the model scores and `fish_scores[i]` is used,\n",
    "      else the minimum of the model scores and `ir_score[i]`.\n",
    "\n",
    "    Args:\n",
    "        threshold_min: lower bound on the weighted mean.\n",
    "        threshold_max: upper bound on the weighted mean.\n",
    "        ir_threshold: cut-off applied to `ir_score[i]`.\n",
    "        fish_threshold: cut-off applied to `fish_scores[i]`.\n",
    "        scores: per-model score sequences, indexed as `scores[model][sample]`.\n",
    "        ir_score: one score per sample; its length sets the number of samples.\n",
    "        fish_scores: one score per sample.\n",
    "        factors: one weight per model. The first `len(factors)` entries of\n",
    "            `scores` are used (previously hard-coded to 10).\n",
    "\n",
    "    Returns:\n",
    "        List with one ensembled score per sample.\n",
    "    \"\"\"\n",
    "    # Loop-invariant: build the weight vector and its sum once\n",
    "    weights = np.array(factors)\n",
    "    weight_total = np.sum(weights)\n",
    "    num_models = len(factors)\n",
    "\n",
    "    ensembled_scores = []\n",
    "    for i in range(len(ir_score)):\n",
    "        # Zero-weight models are excluded from the mean but still take part in min/max\n",
    "        line_scores = [scores[j][i] for j in range(num_models)]\n",
    "        mean_score = np.array(line_scores).dot(weights) / weight_total\n",
    "        if mean_score < threshold_min:\n",
    "            ensembled_scores.append(min(line_scores))\n",
    "        elif mean_score > threshold_max:\n",
    "            ensembled_scores.append(max(line_scores))\n",
    "        elif fish_scores[i] < fish_threshold or ir_score[i] < ir_threshold:\n",
    "            # Either auxiliary score below its cut-off: take the most pessimistic value\n",
    "            ensembled_scores.append(min(ir_score[i], min(line_scores), fish_scores[i]))\n",
    "        elif (weight_total * mean_score + fish_scores[i]) / (weight_total + 1) > 0.5:\n",
    "            ensembled_scores.append(max(fish_scores[i], max(line_scores)))\n",
    "        else:\n",
    "            ensembled_scores.append(min(ir_score[i], min(line_scores)))\n",
    "    return ensembled_scores\n",
    "\n",
    "def get_prediceted_labels(ensembled_score):\n",
    "    \"\"\"Threshold ensembled scores into binary labels.\n",
    "\n",
    "    Args:\n",
    "        ensembled_score: sequence of numeric scores.\n",
    "\n",
    "    Returns:\n",
    "        List of ints of the same length: 1 where the score is strictly\n",
    "        greater than 0.5, otherwise 0.\n",
    "    \"\"\"\n",
    "    return [1 if value > 0.5 else 0 for value in ensembled_score]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Load the two extra per-sample score lists that get_ensembled_scores takes besides `scores`.\n",
    "ir_predict_file = 'submission/IRpredicte57710_mobilelitenetB_55_submission.txt' # result of a single forward pass over all IR images\n",
    "ir_scores,ir_Prefix_str = splitscore(ir_predict_file)\n",
    "\n",
    "\n",
    "# NOTE(review): the variable is named *_51 but the file name says fishnet150_52 - confirm which checkpoint is intended\n",
    "fish150_51 = 'submission/checkpoints_testset_predictes/2019-03-08_19:01:11_fishnet150_52_submission.txt' \n",
    "fish150_scores,fish150_Prefix_str = splitscore(fish150_51)\n",
    "\n",
    "# One weight per entry of `scores`, matched by position; a 0 removes that model from the weighted mean only\n",
    "factors = [1, 0, 1, 1, 0, 1, 0, 0, 1, 1]\n",
    "# Bounds on the weighted mean: below the min / above the max, the min / max model score is taken\n",
    "threshold_min = 0.18\n",
    "threshold_max = 0.8\n",
    "# Between the bounds, an IR or fish score under its cut-off selects the minimum of all scores\n",
    "ir_threshold = 0.54\n",
    "fish_threshold = 0.32\n",
    "ensembled_scores = get_ensembled_scores(threshold_min,threshold_max,ir_threshold, fish_threshold,scores,ir_scores,fish150_scores,factors)\n",
    "# Binary labels: 1 where the ensembled score is > 0.5, else 0\n",
    "prediceted_labels = get_prediceted_labels(ensembled_scores)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save the submission file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Context managers make sure both files are closed, even if a write fails\n",
    "with open('data/test_public_list.txt') as list_file:\n",
    "    test_file_list = list_file.readlines()\n",
    "\n",
    "with open('submission/submission.txt','w') as submission_file:\n",
    "    # Index into test_file_list (rather than zip) so a too-short list still raises IndexError\n",
    "    for i, ensembled in enumerate(ensembled_scores):\n",
    "        # NOTE(review): the score is appended with no separator - presumably each list line\n",
    "        # already ends with a space; confirm against data/test_public_list.txt\n",
    "        submission_file.write(test_file_list[i].rstrip('\\n') + str(ensembled) + '\\n')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
